from resources.model import SimpleCNN
from resources.adapter import AdaptiveClientHyperparams
from resources.fed_runner import (
    client_train_with_temp,
    aggregate_models,
    evaluate_model,
)
from resources.heterogeneity import (
    calculate_global_distribution,
    calculate_heterogeneity_score,
)

from utils.loader import load_and_partition_data
from utils.utils import init_seeds, get_args

import torch.optim as optim
import torch
import copy


def _init_client_params(client_class_distributions, num_classes, num_clients):
    """Build one AdaptiveClientHyperparams per client, seeded from heterogeneity.

    A heterogeneity score is computed for each client by comparing its class
    distribution with the global distribution. That score initializes the
    client's temperature, and a one-line summary is printed per client.

    Args:
        client_class_distributions: Per-client class distributions as returned
            by ``load_and_partition_data``.
        num_classes: Number of target classes.
        num_clients: Number of hyperparameter objects to create.

    Returns:
        list: ``num_clients`` ``AdaptiveClientHyperparams`` instances, indexed
        by client id.
    """
    global_distribution = calculate_global_distribution(
        client_class_distributions, num_classes
    )
    client_params = [AdaptiveClientHyperparams() for _ in range(num_clients)]

    print("\nInitializing client temperatures based on heterogeneity scores:\n")
    # NOTE(review): assumes len(client_class_distributions) <= num_clients;
    # otherwise client_params[i] raises IndexError.
    for i, distribution in enumerate(client_class_distributions):
        het_score = calculate_heterogeneity_score(
            distribution, global_distribution, num_classes
        )
        client_params[i].update_temp_from_heterogeneity(het_score)
        print(
            f"Client {i} - Het Score: {het_score:.4f}, Initial temp: {client_params[i].get_temp():.4f}"
        )
    return client_params


def run_fedchill(args, device):
    """Run FedChill: federated training with per-client temperature scaling.

    Each communication round, every client does the following:
      * copies the current global weights,
      * trains locally with its own temperature and learning rate,
      * is evaluated on its private validation set and on the shared test set,
      * reports its validation accuracy to its ``AdaptiveClientHyperparams``,
        which may adjust the temperature or learning rate for the next round.

    The client models are then merged into the global model by
    ``aggregate_models``, and the global model is evaluated on the test set.
    Detailed logging is emitted only for the final round.

    Args:
        args: Namespace-like config. This function reads ``num_clients``,
            ``alpha``, ``batch_size``, ``frac``, ``dataset``, ``num_classes``,
            ``comm_rounds`` and ``local_epochs``.
        device: ``torch.device`` on which all models are built and trained.

    Returns:
        tuple: A 5-tuple with these elements, in order:
          * ``round_server_acc_list``: global-model test accuracy per round.
          * ``round_client_acc_list``: per client, its test accuracy after
            local training in each round.
          * ``client_temp_history``: per client, the temperature used in each
            round.
          * ``client_lr_history``: per client, the learning rate used in each
            round.
          * ``global_model``: the final aggregated model.
    """
    # The partition uses a fixed seed (42), independent of args.
    (
        client_train_loaders,
        client_val_loaders,
        test_loader,
        client_class_distributions,
    ) = load_and_partition_data(
        num_clients=args.num_clients,
        alpha=args.alpha,
        batch_size=args.batch_size,
        frac=args.frac,
        rand_seed=42,
        dataset=args.dataset,
    )

    num_classes = args.num_classes
    comm_rounds = args.comm_rounds
    local_epochs = args.local_epochs

    # The loader decides the real number of clients.
    num_clients = len(client_train_loaders)

    client_params = _init_client_params(
        client_class_distributions, num_classes, num_clients
    )

    global_model = SimpleCNN(num_classes=num_classes).to(device)
    round_server_acc_list = []
    round_client_acc_list = [[] for _ in range(num_clients)]
    client_temp_history = [[] for _ in range(num_clients)]
    client_lr_history = [[] for _ in range(num_clients)]

    # A deepcopy of a model already on `device` stays on `device`.
    client_models = [copy.deepcopy(global_model) for _ in range(num_clients)]

    for comm_round in range(comm_rounds):
        print(f"\n{'='*20} COMMUNICATION ROUND {comm_round+1} {'='*20}")

        # All verbose output (and training-loop logging) is limited to the
        # final round.
        is_last_round = comm_round == comm_rounds - 1

        # Every client participates in every round.
        selected_clients = list(range(num_clients))

        if is_last_round:
            print("\n--- LOCAL TRAINING OF CLIENTS ---")

        for client_idx in selected_clients:
            model = client_models[client_idx]
            hyperparams = client_params[client_idx]

            # Local training always starts from the current global weights.
            model.load_state_dict(global_model.state_dict())

            curr_lr = hyperparams.get_lr()
            curr_temp = hyperparams.get_temp()
            client_temp_history[client_idx].append(curr_temp)
            client_lr_history[client_idx].append(curr_lr)

            # Rebuilt every round so a learning rate adapted in the previous
            # round takes effect.
            optimizer = optim.SGD(model.parameters(), lr=curr_lr)

            if is_last_round:
                print(
                    f"\n--- Client {client_idx} Local Training with T={curr_temp:.4f}, LR={curr_lr:.6f} ---"
                )

            client_train_with_temp(
                model,
                client_train_loaders[client_idx],
                optimizer,
                local_epochs,
                device,
                temperature=curr_temp,
                print_flag=is_last_round,
            )

            local_acc = evaluate_model(
                model, client_val_loaders[client_idx], device
            )
            if is_last_round:
                print(
                    f"Client {client_idx} Private Data Validation Accuracy: {local_acc:.2f}%"
                )

            client_testset_acc = evaluate_model(model, test_loader, device)
            round_client_acc_list[client_idx].append(client_testset_acc)

            # The validation accuracy drives adaptation of this client's
            # temperature / learning rate for the next round.
            adjusted, param_type, new_value = hyperparams.record_performance(
                local_acc
            )
            if adjusted and is_last_round:
                print(f"Client {client_idx} - {param_type} to {new_value:.6f}")

        aggregate_models(
            global_model, client_models, selected_clients, client_train_loaders
        )

        server_test_val_acc = evaluate_model(global_model, test_loader, device)
        if is_last_round:
            print("\n--- Server Evaluation ---")
            print(f"Server Test Data Validation Accuracy: {server_test_val_acc:.2f}%")

        round_server_acc_list.append(server_test_val_acc)

    return (
        round_server_acc_list,
        round_client_acc_list,
        client_temp_history,
        client_lr_history,
        global_model,
    )


if __name__ == "__main__":
    # Seed first, before argument parsing or any model/data setup.
    init_seeds()
    args = get_args()

    # Prefer the GPU when one is visible to torch.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    banner = f"Running FedChill with arguments: \n {args} \n Device: {device}\n"
    print(banner)
    run_fedchill(args, device)
